import copy
import json
import bz2
from tqdm import tqdm
import argparse
import logging
import os
import random
import argparse
from rank_bm25 import BM25Okapi
import nltk
from nltk.corpus import stopwords
import string
from multiprocessing import Queue, Process, cpu_count, Manager
import time
import torch
import math
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer
# Load the English stopword list, downloading the NLTK corpus on first use.
try:
    nltk_stopwords = stopwords.words('english')
except LookupError:
    # NLTK raises LookupError when the 'stopwords' resource is not installed;
    # a bare except here would also hide unrelated failures (e.g. KeyboardInterrupt).
    nltk.download('stopwords')
    nltk_stopwords = stopwords.words('english')
# Seed every RNG used in this script so sampling is reproducible across runs.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
logging.getLogger().setLevel(logging.INFO)

def read_bz2(addr, return_length=False):
    """Load a bz2-compressed JSON-lines file of ``[key, value]`` records into a dict.

    Each line is decoded with ``json.loads``; its first element becomes the key and
    its second the value (later duplicates overwrite earlier ones).

    :param addr: path of the ``.bz2`` file to read.
    :param return_length: if true, also return the sum of ``len(value)`` over all lines.
    :return: the dict, or ``(dict, total_length)`` when ``return_length`` is true.
    """
    records = []
    value_total = 0
    with bz2.open(addr) as reader:
        for raw_line in tqdm(reader, desc="reading from {}".format(addr)):
            record = json.loads(raw_line)
            value_total += len(record[1])
            records.append(record)
    mapping = dict(records)
    return (mapping, value_total) if return_length else mapping

def save_new_file(new_e2c=None, new_e2c_f=None, new_e2p=None, new_e2p_f=None):
    """Write the given dicts as bz2-compressed JSON lines, one ``[key, value]`` per line.

    ``new_e2c`` is written to ``new_e2c_f`` and then ``new_e2p`` to ``new_e2p_f``;
    a dict passed as ``None`` is skipped. Output is UTF-8 with non-ASCII kept as is.
    """
    for mapping, path in ((new_e2c, new_e2c_f), (new_e2p, new_e2p_f)):
        if mapping is None:
            continue
        with bz2.BZ2File(path, 'w') as writer:
            for pair in tqdm(mapping.items(), desc='save file at ' + path):
                line = json.dumps(pair, ensure_ascii=False) + "\n"
                writer.write(line.encode("utf-8"))

def filter_fre(args):
    """Drop rare entities from ``e2c`` and subsample the contexts of the rest.

    Entities with fewer than ``args.fre`` contexts are removed. Every other entity
    keeps ``min(args.k, len(contexts))`` contexts drawn with ``random.sample``.
    ``e2p`` is pruned to the kept entities plus every entity named by ``item[0]`` of
    a kept context. Results are written to ``processed/e2c_{fre}-{k}.bz2`` and
    ``processed/e2p_{fre}-{k}.bz2`` under ``args.file``.
    """
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    e2c = read_bz2(e2c_addr)
    e2p = read_bz2(e2p_addr)
    logging.info("length of e2c: %d", len(e2c))
    logging.info("length of e2p: %d", len(e2p))
    rm_keys = []
    keep_keys = {}  # insertion-ordered set of entities whose e2p entry is kept

    for key in list(e2c.keys()):
        contexts = e2c[key]
        if len(contexts) < args.fre:
            # Only entities just below the threshold are collected as log examples.
            if len(contexts) == args.fre - 1:
                rm_keys.append(key)
            del e2c[key]
            continue
        # random.sample raises ValueError when k exceeds the population size
        # (possible whenever args.k > args.fre), so cap k at the number available.
        sampled_items = random.sample(contexts, k=min(args.k, len(contexts)))
        e2c[key] = sampled_items
        keep_keys[key] = 1
        for item in sampled_items:
            keep_keys[item[0]] = 1
    logging.info("We remove entity anchors with frequency lower than %d, some examples are %s.", args.fre,
                 rm_keys[:100])
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + '{}-{}'.format(args.fre, args.k) + '.bz2')
    e2p_save_dir = os.path.join(args.file, 'processed', 'e2p_' + '{}-{}'.format(args.fre, args.k) + '.bz2')
    new_e2p = {key: e2p.pop(key) for key in keep_keys}
    save_new_file(e2c, e2c_save_dir, new_e2p, e2p_save_dir)

def filter_max(args):
    """Bernoulli-subsample ``e2c`` so roughly ``args.max`` contexts remain in total.

    Each context is kept independently with probability ``args.max / total``. An
    entity left with no contexts is dropped. ``e2p`` is pruned to kept entities plus
    every entity named by ``item[0]`` of a kept context. Results go to
    ``processed/e2c_{max}.bz2`` and ``processed/e2p_{max}.bz2`` under ``args.file``.
    """
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    e2c = read_bz2(e2c_addr)
    e2p = read_bz2(e2p_addr)
    logging.info("length of e2c: %d", len(e2c))
    logging.info("length of e2p: %d", len(e2p))
    rm_keys = []
    keep_keys = {}  # insertion-ordered set of entities whose e2p entry is kept
    total_count = sum(len(contexts) for contexts in e2c.values())
    # An empty e2c (or one holding only empty lists) used to raise ZeroDivisionError.
    p = args.max / total_count if total_count else 0.0
    for key in list(e2c.keys()):
        sampled_items = [item for item in e2c[key] if random.random() < p]
        if sampled_items:
            e2c[key] = sampled_items
            keep_keys[key] = 1
            for item in sampled_items:
                keep_keys[item[0]] = 1
        else:
            rm_keys.append(key)
            del e2c[key]
    logging.info(
        "Totally we have %d items. We remove entity anchors such that the maximum item number is around %d, some removed examples are %s.",
        total_count, args.max,
        rm_keys[:100])
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + '{}'.format(args.max) + '.bz2')
    e2p_save_dir = os.path.join(args.file, 'processed', 'e2p_' + '{}'.format(args.max) + '.bz2')
    new_e2p = {key: e2p.pop(key) for key in keep_keys}
    save_new_file(e2c, e2c_save_dir, new_e2p, e2p_save_dir)

def _bm25_definition_extra(documents, def_num, min_words=30, max_extra=9):
    """Return how many sentences beyond ``def_num`` the definition needs to reach ``min_words`` tokens (capped at ``max_extra``)."""
    if sum(len(sen) for sen in documents[:def_num]) >= min_words:
        return 0
    for extra in range(1, max_extra + 1):
        if sum(len(sen) for sen in documents[:def_num + extra]) >= min_words:
            return extra
    return max_extra


def _clean_bm25_sentence(sentence, stop_words):
    """Lower-case the tokens of ``sentence``, dropping stopwords and punctuation."""
    cleaned = []
    for token in sentence:
        lowered = token.lower()
        # Test the lower-cased token: NLTK's stopword list is lower-case, so testing
        # the raw token kept capitalised stopwords such as "The".
        # `string.punctuation` is a str, so the second test is a substring check
        # (it also drops empty tokens).
        if lowered in stop_words or token in string.punctuation:
            continue
        cleaned.append(lowered)
    return cleaned


def prepare_bm25_corpus(args, len_data, jobs_queue, output_queue, time_start):
    """Worker loop turning raw entity documents into cleaned BM25 token lists.

    Consumes ``(i_t, target, documents)`` jobs from ``jobs_queue`` until a job whose
    ``target`` is ``None`` arrives. ``documents`` is presumably a list of tokenised
    sentences (lists of str). For each job it puts ``(target, e2p_new)`` on
    ``output_queue``. ``e2p_new`` has these keys:

    * ``offset``: number of leading sentences treated as the definition. This is
      ``args.def_num``, widened by up to 9 sentences until it spans >= 30 tokens.
    * ``definition_clean``: the cleaned tokens of those sentences, flattened.
    * ``documents_clean``: every sentence, lower-cased, with stopwords and
      punctuation removed.

    ``len_data`` and ``time_start`` are only used for progress/ETA logging.
    """
    # Built once per worker: set membership is O(1) vs scanning the list per token.
    stop_words = set(nltk_stopwords)
    while True:
        i_t, target, documents = jobs_queue.get()
        if target is None:
            logging.info('Quit worker')
            break
        if i_t % 100000 == 0:
            passed_time = time.time() - time_start
            estimate_time = passed_time / (i_t + 1) * (len_data - i_t)
            logging.info("prepare_bm25_corpus at No. %d entity '%s', time cost is %f. Estimated remaining time is %f", i_t, target, passed_time, estimate_time)
        # A too-short definition usually means the sentence segmenter split the
        # title/first sentence wrongly, so more sentences are folded into it.
        context_offset = args.def_num + _bm25_definition_extra(documents, args.def_num)
        documents_clean = [_clean_bm25_sentence(sentence, stop_words) for sentence in documents]
        e2p_new = {
            "offset": context_offset,
            'definition_clean': [word for sen in documents_clean[:context_offset] for word in sen],
            "documents_clean": documents_clean,
        }
        output_queue.put((target, e2p_new))

def do_retrieval(args, len_data, bm25, jobs_queue, output_queue, time_start):
    """Worker loop scoring batches with ``bm25.get_batch_scores(args, batch)``.

    Consumes ``(i_t, batch)`` jobs until one whose ``i_t`` is ``None`` arrives and
    puts ``(i_t, scores)`` on ``output_queue`` for every other job. ``len_data`` and
    ``time_start`` only feed the progress/ETA log emitted every 1000 batches.
    """
    while True:
        batch_idx, batch = jobs_queue.get()
        if batch_idx is None:
            logging.info('Quit worker')
            return
        if batch_idx % 1000 == 0:
            elapsed = time.time() - time_start
            remaining = elapsed / (batch_idx + 1) * (len_data - batch_idx)
            logging.info("do_retrieval at No. %d batch, time cost is %f. Estimated remaining time is %f", batch_idx, elapsed, remaining)
        output_queue.put((batch_idx, bm25.get_batch_scores(args, batch)))

def reduce_process(output_queue, all_features):
    '''
    Drain ``(target, doc)`` pairs from ``output_queue`` into ``all_features``.

    Each pair is stored as ``all_features[target] = doc`` (later pairs for the same
    target overwrite earlier ones). The loop stops at the first pair whose
    ``target`` is ``None``.
    :param output_queue: queue yielding ``(target, doc)`` tuples
    :param all_features: mapping (e.g. a Manager dict) that receives the results
    :return: None
    '''
    while True:
        target, doc = output_queue.get()
        if target is None:
            logging.info('Quit Reducer')
            return
        all_features[target] = doc

def filter_bm25_multiprocess(args):
    """Rank each entity's contexts by BM25 against its definition and keep the top share.

    Stage 1 turns ``processed/e2p.bz2`` into a cleaned token corpus using
    ``prepare_bm25_corpus`` workers. The result is cached at
    ``processed/e2p_bm25_corpus.bz2`` and reused when that file exists.

    Stage 2 batches every ``e2c`` entity as ``(target, query_tokens, corpus_ids)``
    and scores the batches with ``do_retrieval`` workers. For each entity it keeps
    the first ``int((len(contexts) + 1) * args.bm25_threshold)`` contexts by
    descending score, each with its score appended. The result is written to
    ``processed/e2c_bm25.bz2``.

    Reads ``args.file``, ``args.processes``, ``args.batch_size`` and
    ``args.bm25_threshold`` here; the workers read further ``args`` fields.
    """
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    maxsize = 10000

    e2p_bmcorpus_save_dir = os.path.join(args.file, 'processed', 'e2p_' + 'bm25_corpus' + '.bz2')
    if os.path.exists(e2p_bmcorpus_save_dir):
        e2p_new = read_bz2(e2p_bmcorpus_save_dir)
    else:
        output_queue = Queue(maxsize=maxsize)
        jobs_queue = Queue(maxsize=maxsize)
        process_count = max(1, args.processes)
        # With process_count == 1 the old split (process_count // 2) started no
        # reducer, so output_queue was never drained: every result was lost and
        # the pipeline could block once the bounded queue filled up.
        number_reducer = max(1, process_count // 2)
        number_worker = max(1, process_count - number_reducer)
        # Context manager so the Manager server process is shut down afterwards.
        with Manager() as manager:
            all_features = manager.dict()
            e2p = read_bz2(e2p_addr)
            len_e2p = len(e2p)
            logging.info("length of e2p: %d", len_e2p)
            reduces = []
            for _ in range(number_reducer):
                reduce = Process(target=reduce_process,
                                 args=(output_queue, all_features))
                reduce.start()
                reduces.append(reduce)

            logging.info("Using %d worker processes for preparing corpus.", process_count)
            time_start = time.time()
            workers = []
            for _ in range(number_worker):
                worker = Process(target=prepare_bm25_corpus,
                                 args=(args, len_e2p, jobs_queue, output_queue, time_start))
                worker.daemon = True  # only live while parent process lives
                worker.start()
                workers.append(worker)

            # Mapper: queue one entity per job, popping it so the parent drops its reference.
            for i_t, target in enumerate(tqdm(list(e2p.keys()), desc='assign jobs')):
                jobs_queue.put((i_t, target, e2p.pop(target)))

            # Stop workers first; reducers are only stopped once every worker has exited.
            for _ in workers:
                jobs_queue.put((None, None, None))
            for w in workers:
                w.join()
            for _ in reduces:
                output_queue.put((None, None))
            for r in reduces:
                r.join()

            e2p_new = dict(all_features)
        del e2p
        save_new_file(new_e2p=e2p_new, new_e2p_f=e2p_bmcorpus_save_dir)

    # Flatten every cleaned sentence into one BM25 corpus; corpus_index maps
    # (entity, sentence_id) -> row of `corpus`.
    corpus = []
    corpus_index = {}
    for target in tqdm(e2p_new, desc='creating bm25 corpus'):
        for i_d, doc_one in enumerate(e2p_new[target].pop("documents_clean")):
            corpus_index[(target, i_d)] = len(corpus)
            corpus.append(doc_one)

    bm25 = BM25Okapi(corpus)

    e2c = read_bz2(e2c_addr)
    len_e2c = len(e2c)
    logging.info("length of e2c: %d", len_e2c)

    # Batchify: args.batch_size entities per batch, each paired with its cleaned
    # definition as query and the corpus rows of its contexts.
    batches = []
    batch_one = []
    original_length = 0
    for target in tqdm(list(e2c.keys()), desc='prepare batches'):
        if len(batch_one) == args.batch_size:
            batches.append(batch_one)
            batch_one = []
        documents = e2c[target]
        original_length += len(documents)
        query = e2p_new[target]['definition_clean']
        documents_retrieval_ids = [corpus_index[(x[0], x[1])] for x in documents]
        batch_one.append((target, query, documents_retrieval_ids))
    if batch_one:
        batches.append(batch_one)

    len_batch = len(batches)
    logging.info("length of bm25 batches: %d.", len_batch)
    total_contexts = 0
    output_queue = Queue(maxsize=maxsize)
    jobs_queue = Queue(maxsize=maxsize)
    process_count = max(1, args.processes)
    number_reducer = max(1, process_count // 5)
    # process_count == 1 used to leave 1 - 1 = 0 retrieval workers, so nothing was scored.
    number_worker = max(1, process_count - number_reducer)
    with Manager() as manager:
        all_features = manager.dict()
        logging.info("Using %d worker processes for bm25 retrieval.", process_count)

        time_start = time.time()
        reduces = []
        for _ in range(number_reducer):
            reduce = Process(target=reduce_process,
                             args=(output_queue, all_features))
            reduce.start()
            reduces.append(reduce)

        workers = []
        for _ in range(number_worker):
            worker = Process(target=do_retrieval,
                             args=(args, len_batch, bm25, jobs_queue, output_queue, time_start))
            worker.start()
            workers.append(worker)

        # Mapper process
        for i_t, batch in enumerate(tqdm(batches, desc='assign jobs')):
            jobs_queue.put((i_t, batch))

        for _ in workers:
            jobs_queue.put((None, None))
        for w in workers:
            w.join()
        for _ in reduces:
            output_queue.put((None, None))
        for r in reduces:
            r.join()

        all_features = dict(all_features)

    e2c_new = {}
    for i_t in tqdm(list(all_features.keys()), desc='prepare new corpus'):
        for item in all_features.pop(i_t):
            target = item[0]
            scores_one = torch.tensor(item[1])
            documents = e2c.pop(target)
            sorted_indexs = torch.sort(scores_one, descending=True)[1]
            assert len(documents) == sorted_indexs.size(0)
            sorted_indexs = sorted_indexs[:int((len(documents) + 1) * args.bm25_threshold)].numpy().tolist()
            # Keep the selected contexts in score order, each with its BM25 score appended.
            new_documents = [documents[idx] + [float(scores_one[idx].cpu())] for idx in sorted_indexs]
            e2c_new[target] = new_documents
            total_contexts += len(new_documents)

    logging.info("finish in {} seconds.".format(time.time() - time_start))
    logging.info("From total {} pairs, we remove entity anchors with with top {} bm25 scores for each entity. Now we have {} context-query pairs".format(original_length, args.bm25_threshold, total_contexts))
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + 'bm25' + '.bz2')
    save_new_file(e2c_new, e2c_save_dir)

def sim_matrix(a, b, eps=1e-8):
    """
    Cosine-similarity matrix between the rows of 2-D tensors ``a`` and ``b``.

    Row norms are floored at ``eps`` for numerical stability, so an all-zero row
    yields similarity 0 instead of NaN. Entry ``[i, j]`` compares ``a[i]`` with ``b[j]``.
    """
    a_unit = a / a.norm(dim=1, keepdim=True).clamp(min=eps)
    b_unit = b / b.norm(dim=1, keepdim=True).clamp(min=eps)
    return torch.mm(a_unit, b_unit.t())
# Tokenize input texts


def norm(a, eps=1e-8):
    """
    Scale ``a`` so each slice along dim 1 has unit L2 norm.

    The norm is floored at ``eps`` for numerical stability, so all-zero rows stay zero.
    """
    row_norms = a.norm(dim=1, keepdim=True)
    return a / row_norms.clamp(min=eps)


def prepare_simcse_corpus(args, len_data, jobs_queue, output_queue, time_start):
    """Worker loop turning tokenised entity documents into plain-text SimCSE inputs.

    Consumes ``(i_t, target, documents)`` jobs from ``jobs_queue`` until a job whose
    ``target`` is ``None`` arrives. For each job it puts ``(target, e2p_new)`` on
    ``output_queue``, with these keys:

    * ``definition``: the tokens of the first ``args.def_num`` sentences, joined
      with spaces. It is widened by up to 9 sentences until it spans >= 30 tokens.
    * ``documents``: every sentence joined with spaces.

    ``len_data`` and ``time_start`` only feed the progress/ETA log.
    """
    while True:
        i_t, target, documents = jobs_queue.get()
        if target is None:
            logging.info('Quit worker')
            break
        if i_t % 100000 == 0:
            passed_time = time.time() - time_start
            estimate_time = passed_time / (i_t + 1) * (len_data - i_t)
            # Was mislabelled "prepare_bm25_corpus" (copy-paste from the BM25 worker).
            logging.info("prepare_simcse_corpus at No. %d entity '%s', time cost is %f. Estimated remaining time is %f", i_t, target, passed_time, estimate_time)
        def_num = args.def_num
        definition = [word for sen in documents[:def_num] for word in sen]  # linearize
        if len(definition) < 30:
            # The sentence segmenter sometimes splits the title/first sentence
            # wrongly, so fold in more sentences until the definition is long enough.
            for offset in range(1, 10):
                definition = [word for sen in documents[:def_num + offset] for word in sen]
                if len(definition) >= 30:
                    break
        e2p_new = {
            'definition': " ".join(definition),
            "documents": [" ".join(x) for x in documents],
        }
        output_queue.put((target, e2p_new))

def _simcse_shard_paths(args, rank):
    """Return the ``(e2p, e2c)`` shard file paths for ``rank`` under ``args.file/processed``."""
    processed = os.path.join(args.file, 'processed')
    return (os.path.join(processed, 'e2p_' + 'simcse_corpus.' + str(rank) + '.bz2'),
            os.path.join(processed, 'e2c_' + 'simcse_corpus.' + str(rank) + '.bz2'))


def _shard_simcse_corpus(args):
    """Stage 1: build SimCSE text inputs and split ``e2c`` into ``args.n_gpu`` shards of roughly equal context count."""
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    e2p = read_bz2(e2p_addr)
    len_e2p = len(e2p)
    logging.info("length of e2p: %d", len_e2p)
    e2c, n_pairs_all = read_bz2(e2c_addr, return_length=True)
    len_e2c = len(e2c)
    logging.info("length of e2c: %d", len_e2c)
    time_start = time.time()
    maxsize = 10000
    output_queue = Queue(maxsize=maxsize)
    jobs_queue = Queue(maxsize=maxsize)
    process_count = max(1, args.processes)
    logging.info("Using %d worker processes for preparing corpus.", process_count)
    number_reducer = max(1, min(5, process_count // 3))
    # process_count == 1 used to leave zero workers, so jobs_queue was never consumed.
    number_worker = max(1, process_count - number_reducer)
    # Context manager so the Manager server process is shut down afterwards.
    with Manager() as manager:
        all_features = manager.dict()
        reduces = []
        for _ in range(number_reducer):
            reduce = Process(target=reduce_process,
                             args=(output_queue, all_features))
            reduce.start()
            reduces.append(reduce)

        workers = []
        for _ in range(number_worker):
            worker = Process(target=prepare_simcse_corpus,
                             args=(args, len_e2p, jobs_queue, output_queue, time_start))
            worker.daemon = True  # only live while parent process lives
            worker.start()
            workers.append(worker)

        # Mapper: queue one entity per job, popping it so the parent drops its reference.
        for i_t, target in enumerate(tqdm(list(e2p.keys()), desc='assign jobs')):
            jobs_queue.put((i_t, target, e2p.pop(target)))

        # Stop workers first; reducers are only stopped once every worker has exited.
        for _ in workers:
            jobs_queue.put((None, None, None))
        for w in workers:
            w.join()
        for _ in reduces:
            output_queue.put((None, None))
        for r in reduces:
            r.join()

        e2p_new = dict(all_features)
    del e2p

    # Shard i takes e2c keys (in order) until the cumulative context count
    # reaches (i + 1) * load. The `end < len_e2c` bound matters: without it the
    # last shard ran past the final key (IndexError) whenever n_pairs_all was not
    # divisible by n_gpu.
    load = math.ceil(n_pairs_all / args.n_gpu)
    e2c_keys = list(e2c.keys())
    load_one = 0
    end = 0
    for i in range(args.n_gpu):
        start = end
        while end < len_e2c and load_one < (i + 1) * load:
            load_one += len(e2c[e2c_keys[end]])
            end += 1
        if i == args.n_gpu - 1:
            end = len_e2c
        e2p_new_f, e2c_new_f = _simcse_shard_paths(args, i)
        keep_keys = {}  # insertion-ordered set of entities whose e2p entry this shard needs
        n_pairs = 0
        e2c_new_one = {}
        for target in e2c_keys[start:end]:
            context = e2c[target]
            e2c_new_one[target] = context
            keep_keys[target] = 1
            for item in context:
                keep_keys[item[0]] = 1
                n_pairs += 1
        e2p_new_one = {key: e2p_new[key] for key in keep_keys}
        save_new_file(new_e2p=e2p_new_one, new_e2p_f=e2p_new_f, new_e2c=e2c_new_one, new_e2c_f=e2c_new_f)
        logging.info("There are {} query-context pairs in No. {} sub file".format(n_pairs, i))
    # Previously a plain (non-f) string that printed literal "{n_gpu}" placeholders.
    logging.info(
        "please run the script the second time with CUDA_VISIBLE_DEVICES=%s python -m torch.distributed.launch "
        "--nproc_per_node=%s wiki_sampling.py --batch_size %s --file %s --method simcse --fp16",
        ",".join(str(rank) for rank in range(args.n_gpu)), args.n_gpu, args.batch_size, args.file)


def _batchify_simcse(e2c, e2p_new, batch_size):
    """Pack ``(target, query, contexts, offset, length)`` chunks into batches of about ``batch_size`` texts."""
    batches = []
    batch_one = []
    num_instance = 0
    for target in tqdm(list(e2c.keys()), desc='prepare batch'):
        documents = e2c[target]
        query = e2p_new[target]['definition']
        context = [e2p_new[x[0]]['documents'][x[1]] for x in documents]
        n_batch_one = math.ceil(len(context) / batch_size)
        for i_b in range(n_batch_one):
            # A batch is closed before adding a chunk once it holds >= batch_size
            # texts (each chunk counts its contexts plus one query).
            if num_instance >= batch_size:
                batches.append(batch_one)
                num_instance = 0
                batch_one = []
            context_one = context[batch_size * i_b: batch_size * (i_b + 1)]
            batch_one.append((target, query, context_one, batch_size * i_b, len(context_one)))
            num_instance = num_instance + len(context_one) + 1
    if batch_one:
        batches.append(batch_one)
    return batches


def filter_simcse(args):
    """Keep, for each entity, the contexts most similar to its definition under SimCSE.

    The function runs in two stages.

    * Stage 1 runs when ``args.local_rank == -1``. It builds text inputs with
      ``prepare_simcse_corpus`` workers and splits ``e2c`` into ``args.n_gpu``
      shards, written as ``processed/e2[pc]_simcse_corpus.{i}.bz2``. It then logs
      the distributed launch command and returns.
    * Stage 2 runs when ``args.local_rank`` is a rank. It loads that rank's shard
      and encodes definitions and contexts with
      ``princeton-nlp/sup-simcse-roberta-large`` (pooler output). For each entity
      it keeps the first ``int((n + 1) * args.simcse_threshold)`` contexts by
      descending cosine similarity, each with its score appended. The result is
      written to ``processed/e2c_simcse.{local_rank}.bz2``.
    """
    from collections import deque

    if args.local_rank == -1:
        _shard_simcse_corpus(args)
        return

    e2p_new_f, e2c_new_f = _simcse_shard_paths(args, args.local_rank)
    e2p_new = read_bz2(e2p_new_f)
    e2c = read_bz2(e2c_new_f)
    # deque: popleft() is O(1), unlike list.pop(0) which made the loop O(n^2).
    batches = deque(_batchify_simcse(e2c, e2p_new, args.batch_size))
    del e2p_new

    time_start = time.time()
    tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large", cache_dir=args.cache_dir)
    model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large", cache_dir=args.cache_dir)
    model.to(args.device)
    if args.fp16:
        try:
            import apex
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        apex.amp.register_half_function(torch, "einsum")
        model = amp.initialize(model, opt_level=args.fp16_opt_level)
        # Distributed training (should be after apex fp16 initialization)
        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
            )
    model.eval()
    all_results = deque()
    for _ in tqdm(range(len(batches)), desc='compute cos similarity'):
        batch_one = batches.popleft()
        all_target = [x[0] for x in batch_one]
        query_text = [x[1] for x in batch_one]
        context_text = [y for x in batch_one for y in x[2]]
        context_offset = [x[3] for x in batch_one]
        context_length = [x[4] for x in batch_one]

        # Queries and contexts are encoded together; the first len(query_text)
        # embeddings are the queries.
        inputs = tokenizer(query_text + context_text, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = inputs.to(args.device)
        with torch.no_grad():
            embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
            query_state = embeddings[: len(query_text)]
            context_state = embeddings[len(query_text):]
            cos_sim = sim_matrix(query_state, context_state)
            assert sum(context_length) == cos_sim.size(1)
        all_results.append((all_target, context_offset, context_length, cos_sim))

    # Row i of each cos_sim holds query i against all contexts of the batch; take
    # the slice belonging to query i and stitch chunks of the same entity together.
    all_results_dict = {}
    for _ in tqdm(range(len(all_results)), desc='results list to dict'):
        all_target, offset_list, length_list, cos_sim = all_results.popleft()
        current_offset = 0
        for i_t, target in enumerate(all_target):
            context_offset = offset_list[i_t]
            context_length = length_list[i_t]
            cos_sim_one = cos_sim[i_t, current_offset: current_offset + context_length]
            current_offset += context_length
            if target not in all_results_dict:
                all_results_dict[target] = (context_offset, cos_sim_one)
            else:
                assert all_results_dict[target][0] < context_offset
                new_cos_sim = torch.cat([all_results_dict[target][1], cos_sim_one])
                all_results_dict[target] = (context_offset, new_cos_sim)

    e2c_new = {}
    total_contexts = 0
    original_length = 0
    for target in tqdm(list(all_results_dict.keys())):
        _, cos_sim_one = all_results_dict.pop(target)
        documents = e2c.pop(target)
        original_length += len(documents)
        assert len(documents) == cos_sim_one.size(0)
        sorted_indexs = torch.sort(cos_sim_one, descending=True)[1]
        sorted_indexs = sorted_indexs[:int((len(documents) + 1) * args.simcse_threshold)].cpu().numpy().tolist()
        new_documents = [documents[idx] + [float(cos_sim_one[idx].cpu())] for idx in sorted_indexs]
        e2c_new[target] = new_documents
        total_contexts += len(new_documents)

    logging.info("finish in {} seconds.".format(time.time() - time_start))
    logging.info("From total {} pairs, we keep top {} simCSE cosine scores pairs. Now we have {} context-query pairs".format( original_length, args.simcse_threshold, total_contexts))
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + 'simcse.' + str(args.local_rank) + '.bz2')
    save_new_file(e2c_new, e2c_save_dir)

def initialize(X, num_clusters):
    """
    Pick initial cluster centres by sampling rows of X without replacement.

    Uses ``np.random.choice`` (so it follows the global NumPy seed).
    :param X: (torch.tensor) matrix whose rows are samples
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) the ``num_clusters`` sampled rows of X
    """
    chosen_rows = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen_rows]

def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        tol=1e-4,
        device=torch.device('cpu')
):
    """
    copied from kmeans_pytorch, fix the bug of empty cluster
    perform kmeans
    :param X: (torch.tensor) matrix
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param tol: (float) threshold [default: 0.0001]
    :param device: (torch.device) device [default: cpu]
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    :raises NotImplementedError: for an unknown ``distance``, or when a cluster
        ends up empty or with a NaN centre (diagnostics are printed first)
    """
    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    X = X.float().to(device)
    initial_state = initialize(X, num_clusters)

    iteration = 0
    while True:
        # Pass `device` through: the distance helpers default to CPU and would
        # otherwise silently move every computation back to the CPU.
        dis = pairwise_distance_function(X, initial_state, device=device)
        choice_cluster = torch.argmin(dis, dim=1)
        cluster_count = {index: 0 for index in range(num_clusters)}
        for index in choice_cluster.cpu().numpy().tolist():
            cluster_count[index] += 1
        # Empty-cluster repair: re-seed each empty cluster with the point farthest
        # from its current centre, taken from a cluster that has more than one point.
        empty_clusters = [index for index in cluster_count if cluster_count[index] == 0]
        while empty_clusters != []:
            empty_index = empty_clusters.pop(0)
            dis_to_center = dis[range(dis.size(0)), choice_cluster]
            new_center_id_ranked = torch.sort(dis_to_center, descending=True)[1].cpu().numpy().tolist()
            while new_center_id_ranked != []:
                new_center_id = new_center_id_ranked.pop(0)
                cluster_id_to_change = int(choice_cluster[new_center_id])
                if cluster_count[cluster_id_to_change] > 1:
                    cluster_count[cluster_id_to_change] -= 1
                    break
            initial_state[empty_index] = X[new_center_id]
            choice_cluster[new_center_id] = empty_index
            dis[new_center_id] = pairwise_distance_function(X[new_center_id].view(1, -1), initial_state, device=device)
        # Explicit checks instead of assert + bare except (asserts vanish under -O).
        for index in range(num_clusters):
            if index not in choice_cluster:
                print('iteration', iteration)
                print("choice", choice_cluster)
                print("index", index)
                print('dis', dis)
                raise NotImplementedError("cluster {} is still empty after repair".format(index))
        initial_state_pre = initial_state.clone()
        for index in range(num_clusters):
            # flatten() (not squeeze()) guarantees a 1-D index even for a
            # single-member cluster, which the repair above can create.
            selected = torch.nonzero(choice_cluster == index).flatten()
            selected = torch.index_select(X, 0, selected)
            initial_state[index] = selected.mean(dim=0)
            if initial_state[index][0].isnan():
                print("choice", choice_cluster)
                print("index", index)
                print("state", initial_state[index])
                raise NotImplementedError("cluster {} has a NaN centre".format(index))

        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        iteration = iteration + 1

        if center_shift ** 2 < tol:
            break

    return choice_cluster.cpu(), initial_state.cpu()

def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu')
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device the distances are computed on [default: 'cpu']
    :return: (torch.tensor) cluster ids, on the CPU
    """
    print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = pairwise_distance
    elif distance == 'cosine':
        pairwise_distance_function = pairwise_cosine
    else:
        raise NotImplementedError

    X = X.float().to(device)

    # Pass `device` through: the distance helpers default to CPU and would
    # otherwise move X straight back, making the `device` argument a no-op.
    dis = pairwise_distance_function(X, cluster_centers, device=device)
    choice_cluster = torch.argmin(dis, dim=1)

    return choice_cluster.cpu()

def pairwise_distance(data1, data2, device=torch.device('cpu')):
    """Squared Euclidean distances between every row of ``data1`` and every row of ``data2``.

    Both inputs are moved to ``device`` first. For 2-D inputs of N and K rows the
    result is (N, K), with size-1 dimensions squeezed away.
    """
    left = data1.to(device)[:, None]
    right = data2.to(device)[None]
    squared = (left - right) ** 2.0
    return squared.sum(dim=-1).squeeze()

def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """Cosine distance (1 - cosine similarity) between every row of ``data1`` and of ``data2``.

    Both inputs are moved to ``device`` first. Rows are normalised without an
    epsilon, so an all-zero row produces NaN. For 2-D inputs of N and K rows the
    result is (N, K), with size-1 dimensions squeezed away.
    """
    left = data1.to(device)[:, None]
    right = data2.to(device)[None]
    left_unit = left / left.norm(dim=-1, keepdim=True)
    right_unit = right / right.norm(dim=-1, keepdim=True)
    similarity = (left_unit * right_unit).sum(dim=-1).squeeze()
    return 1 - similarity

def _prepare_cluster_shards(args):
    """Stage 1 of `cluster_simcse`: de-duplicate, sample and shard the e2c corpus.

    Reads ``e2p.bz2`` and ``e2c.bz2``. For each target, a context whose joined
    ``e2p`` sentence text equals an earlier context's text is dropped. Targets with
    fewer than ``args.fre`` unique contexts are skipped. At most ``args.max_context``
    contexts are randomly kept for the rest. The corpus is then written into
    ``max(1, args.n_gpu)`` files ``e2c_cluster_corpus.<rank>.bz2`` with roughly
    balanced context counts.
    """
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    e2p = read_bz2(e2p_addr)
    e2c = read_bz2(e2c_addr)

    e2c_new = {}
    total_context = 0
    for target in tqdm(list(e2c), desc='prepare batch'):
        # pop so the raw context list can be freed as soon as it is processed
        context = e2c.pop(target)
        seen_texts = set()
        context_non_overlap = []
        for x in context:
            context_text_one = " ".join(e2p[x[0]][x[1]])
            if context_text_one not in seen_texts:
                seen_texts.add(context_text_one)
                context_non_overlap.append(x)
        if len(context_non_overlap) < args.fre:
            continue
        # make an initial sampling to avoid large context size
        context_non_overlap = random.sample(context_non_overlap, k=min(len(context_non_overlap), args.max_context))
        context_non_overlap_text = [" ".join(e2p[x[0]][x[1]]) for x in context_non_overlap]
        total_context += len(context_non_overlap)
        e2c_new[target] = {
            'id': context_non_overlap,
            'text': context_non_overlap_text,
        }
    del e2p, e2c

    # n_gpu is 0 on CPU-only runs; still write (at least) one shard instead of dividing by zero
    n_shards = max(1, args.n_gpu)
    load = math.ceil(total_context / n_shards)
    e2c_keys = list(e2c_new.keys())
    next_key = 0
    n_context = 0
    for i_rank in range(n_shards):
        shard_start = n_context
        if i_rank != n_shards - 1:
            e2c_new_one = {}
            # greedily take targets until the cumulative count passes this shard's quota;
            # stop early when the keys run out instead of popping from an empty list
            while next_key < len(e2c_keys) and n_context <= (i_rank + 1) * load:
                target = e2c_keys[next_key]
                next_key += 1
                e2c_new_one[target] = e2c_new.pop(target)
                n_context += len(e2c_new_one[target]['id'])
        else:
            # the last shard takes whatever is left
            e2c_new_one = e2c_new
            n_context = total_context
        logging.info(
            "There are {} query-context pairs in No. {} sub file".format(n_context - shard_start, i_rank))
        e2c_new_one_f = os.path.join(args.file, 'processed', 'e2c_' + 'cluster_corpus.' + str(i_rank) + '.bz2')
        save_new_file(new_e2c=e2c_new_one, new_e2c_f=e2c_new_one_f)

    logging.info(
        'please run the script the second time with CUDA_VISIBLE_DEVICES={devices} python -m torch.distributed.launch '
        '--nproc_per_node={n_gpu} wiki_sampling.py --batch_size {batch_size} --file {file} --method cluster --fp16'.format(
            devices=",".join(str(i) for i in range(n_shards)),
            n_gpu=n_shards,
            batch_size=args.batch_size,
            file=args.file,
        ))


def _batchify_cluster_inputs(e2c, batch_size):
    """Group ``(target, context ids, context texts)`` triples into batches of ``batch_size`` targets."""
    batches = []
    batch_one = []
    for target, item in tqdm(e2c.items(), desc='prepare batch'):
        batch_one.append((target, item['id'], item['text']))
        if len(batch_one) >= batch_size:
            batches.append(batch_one)
            batch_one = []
    if batch_one:
        batches.append(batch_one)
    return batches


def _load_simcse_encoder(args):
    """Load the SimCSE tokenizer/model onto ``args.device`` (optionally apex fp16 + DDP) in eval mode."""
    # tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens', cache_dir=args.cache_dir)
    # model = AutoModel.from_pretrained('sentence-transformers/bert-base-nli-mean-tokens', cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-roberta-large", cache_dir=args.cache_dir)
    model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-roberta-large", cache_dir=args.cache_dir)
    model.to(args.device)
    if args.fp16:
        try:
            import apex
            apex.amp.register_half_function(torch, "einsum")
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model = amp.initialize(model, opt_level=args.fp16_opt_level)
        # Distributed training (should be after apex fp16 initialization)
        if args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
            )
    model.eval()
    return tokenizer, model


def cluster_simcse(args):
    """Keep ``args.k`` diverse contexts per target via k-means over SimCSE embeddings.

    This is a two-stage workflow selected by ``args.local_rank``.

    * ``local_rank == -1`` (first run): `_prepare_cluster_shards` de-duplicates and
      samples the corpus, writes one shard file per GPU, logs the command for the
      second (distributed) run, and returns.
    * ``local_rank != -1`` (second run, one process per shard): encodes every
      context of shard ``local_rank`` with ``sup-simcse-roberta-large`` (pooler
      output). It runs cosine k-means with ``args.k`` clusters per target and keeps
      one randomly chosen context per cluster. The result is saved to
      ``e2c_cluster-<k>.<local_rank>.bz2``.
    """
    if args.local_rank == -1:
        _prepare_cluster_shards(args)
        return

    e2c_new_f = os.path.join(args.file, 'processed', 'e2c_' + 'cluster_corpus.' + str(args.local_rank) + '.bz2')
    e2c = read_bz2(e2c_new_f)
    # read_bz2(return_length=True) would count the two keys ('id', 'text') of every
    # value instead of the contexts, so count the context ids explicitly
    original_length = sum(len(item['id']) for item in e2c.values())
    batches = _batchify_cluster_inputs(e2c, args.batch_size)

    time_start = time.time()
    tokenizer, model = _load_simcse_encoder(args)

    e2c_cluster = {}
    for batch_one in tqdm(batches, desc='get diverse context'):
        context_text = [text for _, _, texts in batch_one for text in texts]

        inputs = tokenizer(context_text, padding=True, truncation=True, return_tensors="pt", max_length=512)
        inputs = inputs.to(args.device)
        # Get the embeddings
        with torch.no_grad():
            all_context_state = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output
            # the batch's contexts were flattened in order; slice each target's rows back out
            end = 0
            for target, context_id, _ in batch_one:
                start, end = end, end + len(context_id)
                context_state = all_context_state[start: end]
                assert context_state.size(0) == len(context_id)
                cluster_ids_x, cluster_centers = kmeans(X=context_state, num_clusters=args.k, distance='cosine', device=args.device)
                context_id_new = []
                for i_cluster in range(args.k):
                    index_one = torch.where(cluster_ids_x == i_cluster)[0]
                    assert index_one.size(0) >= 1
                    index_selected = random.sample(index_one.cpu().numpy().tolist(), 1)[0]
                    context_id_new.append(context_id[index_selected])
                e2c_cluster[target] = context_id_new

    total_contexts = sum(len(contexts) for contexts in e2c_cluster.values())
    logging.info("finish in {} seconds.".format(time.time() - time_start))
    logging.info("From total {} pairs, we use {} clusters to measure the diversity. After filtering, we have {} context-query pairs".format(original_length, args.k, total_contexts))
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + 'cluster-' + str(args.k) + '.' + str(args.local_rank) + '.bz2')
    save_new_file(e2c_cluster, e2c_save_dir)

def merge_cluster(args):
    """Merge per-shard e2c files into a single e2c file.

    Reads ``e2c_<args.sample_data>.<i>.bz2`` for every shard ``i`` and writes their
    union to ``e2c_<args.sample_data>.bz2``. If a target appears in several
    shards, the entry from the last shard read wins.
    """
    # n_gpu is 0 on CPU-only runs, which would otherwise merge nothing and
    # silently write an empty file
    n_shards = max(1, args.n_gpu)
    e2c_cluster_addr = [os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + "." + str(i) + '.bz2') for i in range(n_shards)]
    e2c_mix = {}
    original_length = {}

    for e2c_cluster_addr_one in e2c_cluster_addr:
        original_length[e2c_cluster_addr_one] = 0
        e2c_cluster_one = read_bz2(e2c_cluster_addr_one)
        for target, context in e2c_cluster_one.items():
            original_length[e2c_cluster_addr_one] += len(context)
            e2c_mix[target] = context

    mix_length = sum(len(context) for context in e2c_mix.values())
    for addr, count in original_length.items():
        logging.info("The number of query-context in {} is {}.".format(addr, count))
    logging.info("Total number of query-context in all shard files is {}.".format(sum(original_length.values())))
    logging.info("After merging all shard files, we get {} pairs.".format(mix_length))

    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir)

def filter_full_random(args):
    """Uniformly sample up to ``args.k`` contexts for every sufficiently frequent target.

    Targets in ``e2c.bz2`` with fewer than ``args.fre`` contexts are dropped. Each
    remaining context is cut to its first four fields, and ``random.sample`` keeps
    at most ``args.k`` of them. The result is written to
    ``e2c_<args.sample_data>.bz2``.
    """
    e2c_addr = os.path.join(args.file, 'processed', "e2c.bz2")
    e2c = read_bz2(e2c_addr)
    original_length = {e2c_addr: 0}

    kept = {}
    for target, contexts in e2c.items():
        original_length[e2c_addr] += len(contexts)
        if len(contexts) >= args.fre:
            kept[target] = [tuple(entry[:4]) for entry in contexts]
    del e2c

    e2c_mix = {}
    for target, contexts in kept.items():
        e2c_mix[target] = random.sample(contexts, k=min(len(contexts), args.k))
    del kept
    mix_length = sum(len(contexts) for contexts in e2c_mix.values())

    for addr, count in original_length.items():
        logging.info("The number of query-context in {} is {}.".format(addr, count))
    logging.info("Total number of query-context in e2c file is {}.".format(sum(original_length.values())))
    logging.info("After sampling the query-context in e2c file, we get {} pairs.".format(mix_length))

    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir)

def filter_mix_30(args):
    """Union BM25 and SimCSE contexts, keeping the leading ``args.mix_threshold`` share of each.

    Reads ``e2c_bm25_50.bz2`` and the ``args.n_gpu`` SimCSE shards
    ``e2c_simcse_50.<i>.bz2``. Targets with fewer than 3 contexts in a source
    are skipped for that source. Otherwise the first
    ``ceil(len * args.mix_threshold)`` contexts are kept (all when the threshold
    is 1; presumably the lists are ranked best-first, so verify against the
    producer). Each kept context is cut to its first four fields, and the lists
    from all sources are unioned. The union preserves first-seen order, so the
    output is reproducible. The matching subset of ``e2p.bz2`` is saved
    alongside ``e2c_<args.sample_data>.bz2``.
    """
    e2c_bm25_addr = os.path.join(args.file, 'processed', "e2c_bm25_50.bz2")
    e2c_simcse_addr = [os.path.join(args.file, 'processed', 'e2c_' + 'simcse_50.' + str(i) + '.bz2') for i in range(args.n_gpu)]
    e2c_bm25 = read_bz2(e2c_bm25_addr)
    e2c_mix = {}
    original_length = {e2c_bm25_addr: 0}

    def _truncate(context):
        # keep the leading mix_threshold share, reduced to the first four fields
        if args.mix_threshold != 1:
            keep_index = math.ceil(len(context) * args.mix_threshold)
            context = context[:keep_index]
        return [tuple(x[:4]) for x in context]

    for target, context in e2c_bm25.items():
        original_length[e2c_bm25_addr] += len(context)
        if len(context) < 3:  # the length of 50% of the context < 3 is equivalent to the length of 100% of the context < 5
            continue
        e2c_mix[target] = _truncate(context)
    del e2c_bm25

    for e2c_simcse_addr_one in e2c_simcse_addr:
        original_length[e2c_simcse_addr_one] = 0
        e2c_simcse_one = read_bz2(e2c_simcse_addr_one)
        for target, context in e2c_simcse_one.items():
            original_length[e2c_simcse_addr_one] += len(context)
            if len(context) < 3:  # the length of 50% of the context < 3 is equivalent to the length of 100% of the context < 5
                continue
            new_context = _truncate(context)
            if target in e2c_mix:
                # order-preserving de-duplication: list(set(...)) would make the order
                # depend on string hash randomization and break reproducibility
                e2c_mix[target] = list(dict.fromkeys(e2c_mix[target] + new_context))
            else:
                e2c_mix[target] = new_context
        del e2c_simcse_one

    mix_length = sum(len(context) for context in e2c_mix.values())
    for addr in original_length:
        logging.info("The number of query-context in {} is {}.".format(addr, original_length[addr]))
    logging.info("Total number of query-context in all bm25 and simcse filtering result is {}.".format(sum(original_length[addr] for addr in original_length)))
    logging.info("After mixing the query-context in all bm25 and simcse file, we get {} pairs.".format(mix_length))

    # keep e2p entries for every target and every entity referenced by its contexts
    keep_keys = {}
    for key, context in e2c_mix.items():
        keep_keys[key] = 1
        for x in context:
            keep_keys[x[0]] = 1
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2p = read_bz2(e2p_addr)
    e2p_mix = {key: e2p.pop(key) for key in keep_keys}
    del e2p  # free the full mapping before writing
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + '.bz2')
    e2p_save_dir = os.path.join(args.file, 'processed', 'e2p_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir, e2p_mix, e2p_save_dir)

def filter_mix_50_random_low5(args):
    """Union BM25 and SimCSE contexts, then randomly keep up to ``args.k`` per target.

    Reads ``e2c_bm25_50.bz2`` and the ``args.n_gpu`` SimCSE shards
    ``e2c_simcse_50.<i>.bz2``. Targets with fewer than 3 contexts in a source
    are skipped for that source. Contexts are cut to their first four fields and
    unioned across sources with an order-preserving de-duplication, so that the
    subsequent seeded ``random.sample`` is reproducible. The matching subset of
    ``e2p.bz2`` is saved alongside ``e2c_<args.sample_data>.bz2``.
    """
    e2c_bm25_addr = os.path.join(args.file, 'processed', "e2c_bm25_50.bz2")
    e2c_simcse_addr = [os.path.join(args.file, 'processed', 'e2c_' + 'simcse_50.' + str(i) + '.bz2') for i in range(args.n_gpu)]
    e2c_bm25 = read_bz2(e2c_bm25_addr)
    e2c_mix = {}
    original_length = {e2c_bm25_addr: 0}

    for target, context in e2c_bm25.items():
        original_length[e2c_bm25_addr] += len(context)
        if len(context) < 3:  # the length of 50% of the context < 3 is equivalent to the length of 100% of the context < 5
            continue
        e2c_mix[target] = [tuple(x[:4]) for x in context]
    del e2c_bm25

    for e2c_simcse_addr_one in e2c_simcse_addr:
        original_length[e2c_simcse_addr_one] = 0
        e2c_simcse_one = read_bz2(e2c_simcse_addr_one)
        for target, context in e2c_simcse_one.items():
            original_length[e2c_simcse_addr_one] += len(context)
            if len(context) < 3:  # the length of 50% of the context < 3 is equivalent to the length of 100% of the context < 5
                continue
            new_context = [tuple(x[:4]) for x in context]
            if target in e2c_mix:
                # order-preserving de-duplication: list(set(...)) would make the order
                # (and therefore random.sample below) depend on string hash randomization
                e2c_mix[target] = list(dict.fromkeys(e2c_mix[target] + new_context))
            else:
                e2c_mix[target] = new_context
        del e2c_simcse_one

    mix_length = 0
    for target in e2c_mix:
        context = e2c_mix[target]
        e2c_mix[target] = random.sample(context, k=min(len(context), args.k))
        mix_length += len(e2c_mix[target])
    for addr in original_length:
        logging.info("The number of query-context in {} is {}.".format(addr, original_length[addr]))
    logging.info("Total number of query-context in all bm25 and simcse filtering result is {}.".format(sum(original_length[addr] for addr in original_length)))
    logging.info("After mixing the query-context in all bm25 and simcse file, we get {} pairs.".format(mix_length))

    # keep e2p entries for every target and every entity referenced by its contexts
    keep_keys = {}
    for key, context in e2c_mix.items():
        keep_keys[key] = 1
        for x in context:
            keep_keys[x[0]] = 1
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2p = read_bz2(e2p_addr)
    e2p_mix = {key: e2p.pop(key) for key in keep_keys}
    del e2p  # free the full mapping before writing
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + '.bz2')
    e2p_save_dir = os.path.join(args.file, 'processed', 'e2p_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir, e2p_mix, e2p_save_dir)

def filter_mix_50_random(args):
    """Union BM25 and SimCSE contexts, then randomly keep up to ``args.k`` per target.

    Same as the low-frequency variant but with a stricter cut-off. Targets with
    fewer than 5 contexts in a source are skipped for that source. Contexts are cut
    to their first four fields and unioned across sources with an order-preserving
    de-duplication, so that the subsequent seeded ``random.sample`` is
    reproducible. The matching subset of ``e2p.bz2`` is saved alongside
    ``e2c_<args.sample_data>.bz2``.
    """
    e2c_bm25_addr = os.path.join(args.file, 'processed', "e2c_bm25_50.bz2")
    e2c_simcse_addr = [os.path.join(args.file, 'processed', 'e2c_' + 'simcse_50.' + str(i) + '.bz2') for i in range(args.n_gpu)]
    e2c_bm25 = read_bz2(e2c_bm25_addr)
    e2c_mix = {}
    original_length = {e2c_bm25_addr: 0}

    for target, context in e2c_bm25.items():
        original_length[e2c_bm25_addr] += len(context)
        if len(context) < 5:  # the length of 50% of the context < 5 is equivalent to the length of 100% of the context < 10
            continue
        e2c_mix[target] = [tuple(x[:4]) for x in context]
    del e2c_bm25

    for e2c_simcse_addr_one in e2c_simcse_addr:
        original_length[e2c_simcse_addr_one] = 0
        e2c_simcse_one = read_bz2(e2c_simcse_addr_one)
        for target, context in e2c_simcse_one.items():
            original_length[e2c_simcse_addr_one] += len(context)
            if len(context) < 5:  # the length of 50% of the context < 5 is equivalent to the length of 100% of the context < 10
                continue
            new_context = [tuple(x[:4]) for x in context]
            if target in e2c_mix:
                # order-preserving de-duplication: list(set(...)) would make the order
                # (and therefore random.sample below) depend on string hash randomization
                e2c_mix[target] = list(dict.fromkeys(e2c_mix[target] + new_context))
            else:
                e2c_mix[target] = new_context
        del e2c_simcse_one

    mix_length = 0
    for target in e2c_mix:
        context = e2c_mix[target]
        e2c_mix[target] = random.sample(context, k=min(len(context), args.k))
        mix_length += len(e2c_mix[target])
    for addr in original_length:
        logging.info("The number of query-context in {} is {}.".format(addr, original_length[addr]))
    logging.info("Total number of query-context in all bm25 and simcse filtering result is {}.".format(sum(original_length[addr] for addr in original_length)))
    logging.info("After mixing the query-context in all bm25 and simcse file, we get {} pairs.".format(mix_length))

    # keep e2p entries for every target and every entity referenced by its contexts
    keep_keys = {}
    for key, context in e2c_mix.items():
        keep_keys[key] = 1
        for x in context:
            keep_keys[x[0]] = 1
    e2p_addr = os.path.join(args.file, 'processed', "e2p.bz2")
    e2p = read_bz2(e2p_addr)
    e2p_mix = {key: e2p.pop(key) for key in keep_keys}
    del e2p  # free the full mapping before writing
    e2c_save_dir = os.path.join(args.file, 'processed', 'e2c_' + args.sample_data + '.bz2')
    e2p_save_dir = os.path.join(args.file, 'processed', 'e2p_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir, e2p_mix, e2p_save_dir)

def filter_mix_50_bottom10(args):
    """Mix the lowest-scored BM25 and SimCSE contexts into at most ``args.k`` per target.

    BM25 (``e2c_bm25_50.bz2``): targets with at least ``args.k // 2`` contexts keep
    the ``args.k // 2`` contexts with the smallest value in field ``[4]``.
    SimCSE (``e2c_simcse_50.<i>.bz2`` for ``i < args.n_gpu``): targets with at least
    ``args.k // 2`` contexts are sorted by ascending field ``[4]``. They then top up
    an existing BM25 list with unseen contexts until it holds ``args.k`` entries, or
    contribute their first ``args.k`` contexts when the target has no entry yet.
    Contexts are cut to their first four fields. The matching subset of
    ``e2p.bz2`` is saved alongside ``e2c_<args.sample_data>.bz2``.
    """
    processed_dir = os.path.join(args.file, 'processed')
    e2c_bm25_addr = os.path.join(processed_dir, "e2c_bm25_50.bz2")
    e2c_simcse_addr = [os.path.join(processed_dir, 'e2c_' + 'simcse_50.' + str(i) + '.bz2') for i in range(args.n_gpu)]
    e2c_bm25 = read_bz2(e2c_bm25_addr)
    e2c_mix = {}
    original_length = {e2c_bm25_addr: 0}
    half_k = args.k // 2

    for target, context in e2c_bm25.items():
        original_length[e2c_bm25_addr] += len(context)
        if len(context) >= half_k:
            lowest = sorted(context, key=lambda item: item[4])[:half_k]
            e2c_mix[target] = [tuple(item[:4]) for item in lowest]
    del e2c_bm25

    for e2c_simcse_addr_one in e2c_simcse_addr:
        original_length[e2c_simcse_addr_one] = 0
        e2c_simcse_one = read_bz2(e2c_simcse_addr_one)
        for target, context in e2c_simcse_one.items():
            original_length[e2c_simcse_addr_one] += len(context)
            if len(context) < half_k:
                continue
            ranked = [tuple(item[:4]) for item in sorted(context, key=lambda item: item[4])]
            if target not in e2c_mix:
                e2c_mix[target] = ranked[:args.k]
                continue
            merged = e2c_mix[target]  # extended in place
            assert len(merged) < args.k
            remaining = args.k - len(merged)
            for candidate in ranked:
                if remaining == 0:
                    break
                if candidate not in merged:
                    merged.append(candidate)
                    remaining -= 1
        del e2c_simcse_one

    mix_length = sum(len(context) for context in e2c_mix.values())
    for addr, count in original_length.items():
        logging.info("The number of query-context in {} is {}.".format(addr, count))
    logging.info("Total number of query-context in all bm25 and simcse filtering result is {}.".format(sum(original_length.values())))
    logging.info("After mixing the query-context in all bm25 and simcse file, we get {} pairs.".format(mix_length))

    # keep e2p entries for every target and every entity its contexts reference
    keep_keys = {}
    for key, context in e2c_mix.items():
        keep_keys[key] = 1
        for item in context:
            keep_keys[item[0]] = 1
    e2p = read_bz2(os.path.join(processed_dir, "e2p.bz2"))
    e2p_mix = {key: e2p.pop(key) for key in keep_keys}
    e2c_save_dir = os.path.join(processed_dir, 'e2c_' + args.sample_data + '.bz2')
    e2p_save_dir = os.path.join(processed_dir, 'e2p_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir, e2p_mix, e2p_save_dir)

def filter_mix_50_top10(args):
    """Mix the highest-scored BM25 and SimCSE contexts into at most ``args.k`` per target.

    BM25 (``e2c_bm25_50.bz2``): targets with at least ``args.k // 2`` contexts keep
    the ``args.k // 2`` contexts with the largest value in field ``[4]``.
    SimCSE (``e2c_simcse_50.<i>.bz2`` for ``i < args.n_gpu``): targets with at least
    ``args.k // 2`` contexts are sorted by descending field ``[4]``. They then top up
    an existing BM25 list with unseen contexts until it holds ``args.k`` entries, or
    contribute their first ``args.k`` contexts when the target has no entry yet.
    Contexts are cut to their first four fields. The matching subset of
    ``e2p.bz2`` is saved alongside ``e2c_<args.sample_data>.bz2``.
    """
    processed_dir = os.path.join(args.file, 'processed')
    e2c_bm25_addr = os.path.join(processed_dir, "e2c_bm25_50.bz2")
    e2c_simcse_addr = [os.path.join(processed_dir, 'e2c_' + 'simcse_50.' + str(i) + '.bz2') for i in range(args.n_gpu)]
    e2c_bm25 = read_bz2(e2c_bm25_addr)
    e2c_mix = {}
    original_length = {e2c_bm25_addr: 0}
    half_k = args.k // 2

    for target, context in e2c_bm25.items():
        original_length[e2c_bm25_addr] += len(context)
        if len(context) >= half_k:
            highest = sorted(context, key=lambda item: -item[4])[:half_k]
            e2c_mix[target] = [tuple(item[:4]) for item in highest]
    del e2c_bm25

    for e2c_simcse_addr_one in e2c_simcse_addr:
        original_length[e2c_simcse_addr_one] = 0
        e2c_simcse_one = read_bz2(e2c_simcse_addr_one)
        for target, context in e2c_simcse_one.items():
            original_length[e2c_simcse_addr_one] += len(context)
            if len(context) < half_k:
                continue
            ranked = [tuple(item[:4]) for item in sorted(context, key=lambda item: -item[4])]
            if target not in e2c_mix:
                e2c_mix[target] = ranked[:args.k]
                continue
            merged = e2c_mix[target]  # extended in place
            assert len(merged) < args.k
            remaining = args.k - len(merged)
            for candidate in ranked:
                if remaining == 0:
                    break
                if candidate not in merged:
                    merged.append(candidate)
                    remaining -= 1
        del e2c_simcse_one

    mix_length = sum(len(context) for context in e2c_mix.values())
    for addr, count in original_length.items():
        logging.info("The number of query-context in {} is {}.".format(addr, count))
    logging.info("Total number of query-context in all bm25 and simcse filtering result is {}.".format(sum(original_length.values())))
    logging.info("After mixing the query-context in all bm25 and simcse file, we get {} pairs.".format(mix_length))

    # keep e2p entries for every target and every entity its contexts reference
    keep_keys = {}
    for key, context in e2c_mix.items():
        keep_keys[key] = 1
        for item in context:
            keep_keys[item[0]] = 1
    e2p = read_bz2(os.path.join(processed_dir, "e2p.bz2"))
    e2p_mix = {key: e2p.pop(key) for key in keep_keys}
    e2c_save_dir = os.path.join(processed_dir, 'e2c_' + args.sample_data + '.bz2')
    e2p_save_dir = os.path.join(processed_dir, 'e2p_' + args.sample_data + '.bz2')
    save_new_file(e2c_mix, e2c_save_dir, e2p_mix, e2p_save_dir)

def filter(args):
    """Run the filtering/sampling routine selected by ``args.method``.

    ``'mix'`` currently runs `filter_mix_50_random_low5`. The alternative mix
    strategies (`filter_mix_50_random`, `filter_mix_50_bottom10`,
    `filter_mix_50_top10`) can be swapped in here. An unrecognised method does
    nothing.
    """
    # lambdas defer the name lookup, so only the selected routine must exist
    dispatch = {
        'fre': lambda: filter_full_random(args),
        'max': lambda: filter_max(args),
        'bm25': lambda: filter_bm25_multiprocess(args),
        'simcse': lambda: filter_simcse(args),
        'mix': lambda: filter_mix_50_random_low5(args),
        'cluster': lambda: cluster_simcse(args),
        'merge': lambda: merge_cluster(args),
    }
    runner = dispatch.get(args.method)
    if runner is not None:
        runner()

if __name__ == "__main__":
    # command-line entry point: parse options, pick the device, then dispatch via filter()
    default_process_count = max(1, cpu_count() - 1)
    parser = argparse.ArgumentParser()
    parser.add_argument("--file", type=str, default="./en",
                        help="e2c file directory")
    parser.add_argument("--fre", type=int, default=10,
                        help="minimum number of contexts an entity needs in order to be kept for sampling")
    parser.add_argument("--k", type=int, default=10,
                        help="sample number for each entity")
    parser.add_argument("--max", type=int, default=None,
                        help="we sample instances to a max number.")
    parser.add_argument("--def_num", type=int, default=1,
                        help="we use the first def_num sentences as the definition for each entity")
    parser.add_argument("--processes", type=int, default=default_process_count,
                        help="Number of processes to use (default %(default)s)")
    parser.add_argument("--method", type=str, default='bm25', choices=['max', 'fre', 'bm25','simcse', 'mix', "cluster", 'merge'],
                        help="perform what kinds of filtering method")
    parser.add_argument("--bm25_threshold", type=float, default=1,
                        help="the rate to sample bm25 ranked data")
    parser.add_argument("--simcse_threshold", type=float, default=1,
                        help="the rate to sample simcse ranked data")
    parser.add_argument("--mix_threshold", type=float, default=0.6,
                        help="the rate to sample mix ranked data")
    parser.add_argument("--max_context", type=int, default=50,
                        help="Some entities have large size of context. We make some random sampling before further preprocessing. ")
    parser.add_argument("--batch_size", default=3, type=int, help="Batch size per GPU/CPU for training.")
    parser.add_argument("--sample_data", type=str, default='bm25-simcse-30', help="the sampled file file for generating data")
    parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
             "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument("--no_cuda", action="store_true", help="Whether not to use CUDA when available")
    parser.add_argument(
        "--cache_dir",
        default="./cache",
        type=str,
        help="Where do you want to store the pre-trained models downloaded from huggingface.co",
    )
    args = parser.parse_args()
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        # n_gpu is 0 on CPU-only runs
        args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = torch.cuda.device_count()
    args.device = device
    filter(args)
